// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package ssh

import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io"

	multierror "github.com/hashicorp/go-multierror"
	"github.com/hashicorp/vault/builtin/logical/ssh/managed_key"
	"github.com/hashicorp/vault/sdk/framework"
	"github.com/hashicorp/vault/sdk/helper/cryptoutil"
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/mikesmitty/edkey"
	"golang.org/x/crypto/ssh"
)

// Identifiers selecting which half of a locally stored CA key pair to read
// (accepted as the keyType argument of readStoredKeyEntry / readStoredKey).
const (
	caPublicKey  = "ca_public_key"
	caPrivateKey = "ca_private_key"
)

// Storage paths for CA key material. The *Deprecated paths hold the bare key
// text written by older layouts; readStoredKeyEntry falls back to them and,
// when migration is allowed, moves their contents to the current paths.
const (
	caPublicKeyStoragePath            = "config/ca_public_key"
	caPublicKeyStoragePathDeprecated  = "public_key"
	caPrivateKeyStoragePath           = "config/ca_private_key"
	caPrivateKeyStoragePathDeprecated = "config/ca_bundle"
	caManagedKeyStoragePath           = "config/ca_managed_key"
)

// keyStorageEntry is the JSON document stored at caPublicKeyStoragePath and
// caPrivateKeyStoragePath. Entries found at the deprecated paths (which hold
// only the bare key text) are wrapped into this shape when read.
type keyStorageEntry struct {
	// Key is the key text as supplied by the caller or produced by
	// generateSSHKeyPair (for generated keys: an authorized_keys line for the
	// public half, a PEM block for the private half).
	Key string `json:"key" structs:"key" mapstructure:"key"`
}

// managedKeyStorageEntry is the JSON document stored at
// caManagedKeyStoragePath when the CA is backed by a managed key. Only the
// identifiers needed to look the key up again and its public half are kept.
type managedKeyStorageEntry struct {
	// KeyId and KeyName are copied from the ManagedKeyInfo resolved in
	// createManagedKey.
	KeyId   managed_key.UUIDKey `json:"key_id" structs:"key_id" mapstructure:"key_id"`
	KeyName managed_key.NameKey `json:"key_name" structs:"key_name" mapstructure:"key_name"`
	// PublicKey is the managed key's public half in authorized_keys format
	// (produced by ssh.MarshalAuthorizedKey in createManagedKey).
	PublicKey string              `json:"public_key" structs:"public_key" mapstructure:"public_key"`
}

// pathConfigCA defines the "config/ca" endpoint, which configures the CA key
// used by this mount to sign SSH certificates.
//
// Update stores a CA key: a managed key (by name or id), an explicit
// public/private key pair, or an internally generated pair. Read returns the
// stored public key. Delete removes all CA key material. Recover restores the
// CA from a loaded snapshot.
func pathConfigCA(b *backend) *framework.Path {
	return &framework.Path{
		Pattern: "config/ca",

		DisplayAttrs: &framework.DisplayAttributes{
			OperationPrefix: operationPrefixSSH,
		},

		Fields: map[string]*framework.FieldSchema{
			"private_key": {
				Type:        framework.TypeString,
				Description: `Private half of the SSH key that will be used to sign certificates.`,
			},
			"public_key": {
				Type:        framework.TypeString,
				Description: `Public half of the SSH key that will be used to sign certificates.`,
			},
			"generate_signing_key": {
				Type:        framework.TypeBool,
				Description: `Generate SSH key pair internally rather than use the private_key and public_key fields. If managed key config is provided, this field is ignored.`,
				Default:     true,
			},
			"key_type": {
				Type:        framework.TypeString,
				Description: `Specifies the desired key type when generating; could be an OpenSSH key type identifier (ssh-rsa, ecdsa-sha2-nistp256, ecdsa-sha2-nistp384, ecdsa-sha2-nistp521, or ssh-ed25519) or an algorithm (rsa, ec, ed25519).`,
				Default:     "ssh-rsa",
			},
			"key_bits": {
				Type:        framework.TypeInt,
				Description: `Specifies the desired key bits when generating variable-length keys (such as when key_type="ssh-rsa") or which NIST P-curve to use when key_type="ec" (256, 384, or 521).`,
				Default:     0,
			},
			"managed_key_name": {
				Type:        framework.TypeString,
				Description: `The name of the managed key to use. When using a managed key, this field or managed_key_id is required.`,
			},
			"managed_key_id": {
				Type:        framework.TypeString,
				Description: `The id of the managed key to use. When using a managed key, this field or managed_key_name is required.`,
			},
		},

		Operations: map[logical.Operation]framework.OperationHandler{
			logical.UpdateOperation: &framework.PathOperation{
				Callback: b.pathConfigCAUpdate,
				DisplayAttrs: &framework.DisplayAttributes{
					OperationVerb:   "configure",
					OperationSuffix: "ca",
				},
			},
			logical.DeleteOperation: &framework.PathOperation{
				Callback: b.pathConfigCADelete,
				DisplayAttrs: &framework.DisplayAttributes{
					OperationSuffix: "ca-configuration",
				},
			},
			logical.ReadOperation: &framework.PathOperation{
				Callback: b.pathConfigCARead,
				DisplayAttrs: &framework.DisplayAttributes{
					OperationSuffix: "ca-configuration",
				},
			},
			logical.RecoverOperation: &framework.PathOperation{
				Callback: b.pathConfigCARecover,
			},
		},

		HelpSynopsis: `Set the SSH private key used for signing certificates.`,
		HelpDescription: `This sets the CA information used for certificates generated by this
mount. The fields must be in the standard private and public SSH format.

For security reasons, the private key cannot be retrieved later.

Read operations will return the public key, if already stored/generated.`,
	}
}

// pathConfigCARead returns the configured CA public key (stored locally or
// taken from the managed key entry). It returns an error response when no CA
// key has been configured.
func (b *backend) pathConfigCARead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	// Writes to a loaded snapshot storage are forbidden, so legacy-path
	// migration must be skipped when this read targets a snapshot.
	migrate := !req.IsSnapshotReadOrList()

	key, err := getCAPublicKey(ctx, req.Storage, migrate)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to read CA public key: %w", err)
	case key == "":
		return logical.ErrorResponse("keys haven't been configured yet"), nil
	}

	return &logical.Response{
		Data: map[string]interface{}{"public_key": key},
	}, nil
}

// pathConfigCADelete removes every storage location that may hold CA key
// material: the current and deprecated private/public key paths and the
// managed key entry. It stops at the first delete that fails.
func (b *backend) pathConfigCADelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	paths := []string{
		caPrivateKeyStoragePath,
		caPrivateKeyStoragePathDeprecated,
		caPublicKeyStoragePath,
		caPublicKeyStoragePathDeprecated,
		caManagedKeyStoragePath,
	}
	for _, p := range paths {
		if err := req.Storage.Delete(ctx, p); err != nil {
			return nil, err
		}
	}
	return nil, nil
}

// readStoredKeyEntry fetches the raw storage entry for a locally stored CA key.
// keyType must be caPublicKey or caPrivateKey; anything else is an error.
//
// The current storage path is checked first. If nothing is stored there, the
// legacy (deprecated) path is checked. A legacy entry holds only the bare key
// text, so it is wrapped into a JSON keyStorageEntry addressed at the current
// path. Callers can therefore decode it uniformly and write it back unchanged.
//
// When allowMigration is true, a legacy entry is also written to the current
// path and removed from the legacy path. When it is false, storage is only
// read, which is required for read-only snapshot storage.
//
// The returned entry is nil (with a nil error) when neither path holds a key.
func readStoredKeyEntry(ctx context.Context, storage logical.Storage, keyType string, allowMigration bool) (*logical.StorageEntry, error) {
	var path, deprecatedPath string
	switch keyType {
	case caPrivateKey:
		path = caPrivateKeyStoragePath
		deprecatedPath = caPrivateKeyStoragePathDeprecated
	case caPublicKey:
		path = caPublicKeyStoragePath
		deprecatedPath = caPublicKeyStoragePathDeprecated
	default:
		return nil, fmt.Errorf("unrecognized key type %q", keyType)
	}

	entry, err := storage.Get(ctx, path)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA key of type %q: %w", keyType, err)
	}

	if entry == nil {
		// Not found at the current path: fall back to the legacy path.
		entry, err = storage.Get(ctx, deprecatedPath)
		if err != nil {
			return nil, fmt.Errorf("failed to read CA key of type %q from deprecated path: %w", keyType, err)
		}

		if entry != nil {
			// Rewrap the bare legacy value into the JSON layout callers expect,
			// addressed at the current path. This serves both migration and
			// callers that decode or re-store the entry.
			entry, err = logical.StorageEntryJSON(path, keyStorageEntry{
				Key: string(entry.Value),
			})
			if err != nil {
				return nil, err
			}
			// Migration is disabled when the caller cannot or should not write,
			// e.g. when reading from a loaded snapshot storage.
			if allowMigration {
				if err := storage.Put(ctx, entry); err != nil {
					return nil, fmt.Errorf("failed to migrate CA key of type %q: %w", keyType, err)
				}
				if err := storage.Delete(ctx, deprecatedPath); err != nil {
					return nil, fmt.Errorf("failed to remove deprecated CA key of type %q after migration: %w", keyType, err)
				}
			}
		}
	}
	return entry, nil
}

// readStoredKey loads a locally stored CA key of the given type (caPublicKey
// or caPrivateKey) and decodes it into a keyStorageEntry. Legacy entries are
// handled by readStoredKeyEntry, and allowMigration is passed through to it.
// It returns (nil, nil) when no key is stored.
// ignore-nil-nil-function-check
func readStoredKey(ctx context.Context, storage logical.Storage, keyType string, allowMigration bool) (*keyStorageEntry, error) {
	raw, err := readStoredKeyEntry(ctx, storage, keyType, allowMigration)
	switch {
	case err != nil:
		return nil, err
	case raw == nil:
		return nil, nil
	}

	decoded := new(keyStorageEntry)
	if err := raw.DecodeJSON(decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}

// pathConfigCAUpdate configures the CA key. It refuses to run if any CA key is
// already present. Sources are tried in this order of precedence:
//
//  1. managed_key_id / managed_key_name: a managed key reference is stored.
//     generate_signing_key is ignored and nothing is returned.
//  2. public_key and private_key both set: both are parsed, then stored.
//  3. generate_signing_key: a new pair of key_type/key_bits is generated and
//     stored.
//
// If none applies, an error response is returned. Whenever generate_signing_key
// is in effect (cases 2 and 3), the stored public key is echoed in the response.
func (b *backend) pathConfigCAUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	alreadyConfigured, err := caKeysConfigured(ctx, req.Storage)
	if err != nil {
		return nil, err
	}
	if alreadyConfigured {
		return logical.ErrorResponse("keys are already configured; delete them before reconfiguring"), nil
	}

	mkName := data.Get("managed_key_name").(string)
	mkID := data.Get("managed_key_id").(string)

	// A managed key wins over every other source; no key is echoed back.
	if mkName != "" || mkID != "" {
		if err := b.createManagedKey(ctx, req.Storage, mkName, mkID); err != nil {
			return nil, err
		}
		return nil, nil
	}

	pub := data.Get("public_key").(string)
	priv := data.Get("private_key").(string)
	generate := data.Get("generate_signing_key").(bool)

	switch {
	case pub != "" && priv != "":
		if _, parseErr := ssh.ParsePrivateKey([]byte(priv)); parseErr != nil {
			return logical.ErrorResponse(fmt.Sprintf("Unable to parse private_key as an SSH private key: %v", parseErr)), nil
		}
		if _, parseErr := parsePublicSSHKey(pub); parseErr != nil {
			return logical.ErrorResponse(fmt.Sprintf("Unable to parse public_key as an SSH public key: %v", parseErr)), nil
		}
	case generate:
		wantType := data.Get("key_type").(string)
		wantBits := data.Get("key_bits").(int)
		pub, priv, err = generateSSHKeyPair(b.Backend.GetRandomReader(), wantType, wantBits)
		if err != nil {
			return nil, err
		}
	default:
		return logical.ErrorResponse("if generate_signing_key is false, either both public_key and private_key or a managed key must be provided"), nil
	}

	if err := createStoredKey(ctx, req.Storage, pub, priv); err != nil {
		return nil, err
	}

	if !generate {
		return nil, nil
	}
	return &logical.Response{
		Data: map[string]interface{}{
			"public_key": pub,
		},
	}, nil
}

// createStoredKey persists a locally held CA key pair. The public key goes to
// caPublicKeyStoragePath and the private key to caPrivateKeyStoragePath, each
// wrapped in a keyStorageEntry. Both keys must be non-empty.
//
// Both entries are built before anything is written, so an encoding failure
// cannot leave half a pair in storage. If writing the private key fails, the
// public key written just before is deleted. The returned error always
// describes the private-key write failure. If that cleanup also fails, the
// error combines both failures.
func createStoredKey(ctx context.Context, s logical.Storage, publicKey, privateKey string) error {
	if publicKey == "" || privateKey == "" {
		return errors.New("failed to generate or parse the keys")
	}

	publicEntry, err := logical.StorageEntryJSON(caPublicKeyStoragePath, &keyStorageEntry{
		Key: publicKey,
	})
	if err != nil {
		return err
	}

	privateEntry, err := logical.StorageEntryJSON(caPrivateKeyStoragePath, &keyStorageEntry{
		Key: privateKey,
	})
	if err != nil {
		return err
	}

	// Save the public key.
	if err := s.Put(ctx, publicEntry); err != nil {
		return err
	}

	// Save the private key; on failure, roll back the public key so storage
	// does not end up holding an unusable half of the pair.
	if err := s.Put(ctx, privateEntry); err != nil {
		storeErr := fmt.Errorf("failed to store CA private key: %w", err)

		if delErr := s.Delete(ctx, caPublicKeyStoragePath); delErr != nil {
			return multierror.Append(storeErr, fmt.Errorf("failed to cleanup CA public key: %w", delErr))
		}

		return storeErr
	}

	return nil
}

// generateSSHKeyPair creates a new signing key pair. It returns the public key
// as an authorized_keys line and the private key as a PEM block.
//
// keyType is either an OpenSSH algorithm name or a generic algorithm:
//   - ssh-rsa / "rsa": RSA. keyBits defaults to 4096 when 0; fewer than 2048
//     bits is rejected. PEM type "RSA PRIVATE KEY" (PKCS#1).
//   - ecdsa-sha2-nistp256/384/521: ECDSA on the named curve (keyBits ignored).
//   - "ec": ECDSA with keyBits selecting the curve (0 or 256, 384, 521).
//     PEM type "EC PRIVATE KEY".
//   - ssh-ed25519 / "ed25519": Ed25519, PEM type "OPENSSH PRIVATE KEY".
//
// A nil randomSource falls back to crypto/rand.Reader.
func generateSSHKeyPair(randomSource io.Reader, keyType string, keyBits int) (string, string, error) {
	rng := randomSource
	if rng == nil {
		rng = rand.Reader
	}

	var (
		pub   crypto.PublicKey
		block *pem.Block
	)

	switch keyType {
	case ssh.KeyAlgoRSA, "rsa":
		bits := keyBits
		if bits == 0 {
			bits = 4096
		}
		if bits < 2048 {
			return "", "", fmt.Errorf("refusing to generate weak %v key: %v bits < 2048 bits", keyType, bits)
		}

		rsaKey, err := cryptoutil.GenerateRSAKey(rng, bits)
		if err != nil {
			return "", "", err
		}
		block = &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(rsaKey)}
		pub = rsaKey.Public()

	case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521, "ec":
		// Named OpenSSH types fix the curve; for "ec" the bit count picks it.
		var curve elliptic.Curve
		switch {
		case keyType == ssh.KeyAlgoECDSA256, keyType == "ec" && (keyBits == 0 || keyBits == 256):
			curve = elliptic.P256()
		case keyType == ssh.KeyAlgoECDSA384, keyType == "ec" && keyBits == 384:
			curve = elliptic.P384()
		case keyType == ssh.KeyAlgoECDSA521, keyType == "ec" && keyBits == 521:
			curve = elliptic.P521()
		default:
			return "", "", fmt.Errorf("unknown ECDSA key pair algorithm and bits: %v / %v", keyType, keyBits)
		}

		ecKey, err := ecdsa.GenerateKey(curve, rng)
		if err != nil {
			return "", "", err
		}
		der, err := x509.MarshalECPrivateKey(ecKey)
		if err != nil {
			return "", "", err
		}
		block = &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
		pub = ecKey.Public()

	case ssh.KeyAlgoED25519, "ed25519":
		_, edKey, err := ed25519.GenerateKey(rng)
		if err != nil {
			return "", "", err
		}
		raw := edkey.MarshalED25519PrivateKey(edKey)
		if raw == nil {
			return "", "", errors.New("unable to marshal ed25519 private key")
		}
		block = &pem.Block{Type: "OPENSSH PRIVATE KEY", Bytes: raw}
		pub = edKey.Public()

	default:
		return "", "", fmt.Errorf("unknown ssh key pair algorithm: %v", keyType)
	}

	sshPub, err := ssh.NewPublicKey(pub)
	if err != nil {
		return "", "", err
	}

	authorized := string(ssh.MarshalAuthorizedKey(sshPub))
	encodedPrivate := string(pem.EncodeToMemory(block))
	return authorized, encodedPrivate, nil
}

// createManagedKey resolves a managed key and stores a reference to it at
// caManagedKeyStoragePath. The id takes precedence when both id and name are
// given. Only the key's id, name and public half (authorized_keys format) are
// persisted.
//
// At least one of managedKeyName and managedKeyId must be non-empty. Errors
// from the managed key lookup are wrapped with %w so callers can inspect them.
func (b *backend) createManagedKey(ctx context.Context, s logical.Storage, managedKeyName, managedKeyId string) error {
	var keyInfo *managed_key.ManagedKeyInfo
	var err error

	switch {
	case managedKeyId != "":
		keyInfo, err = managed_key.GetManagedKeyInfo(ctx, b, managed_key.UUIDKey(managedKeyId))
	case managedKeyName != "":
		keyInfo, err = managed_key.GetManagedKeyInfo(ctx, b, managed_key.NameKey(managedKeyName))
	default:
		// Without this guard keyInfo would stay nil and the dereference below
		// would panic (e.g. when recovering an entry with no name or id).
		return errors.New("either a managed key name or a managed key id must be provided")
	}

	if err != nil {
		return fmt.Errorf("error retrieving public key: %w", err)
	}
	if keyInfo == nil {
		return errors.New("error retrieving public key: no managed key information returned")
	}

	entry, err := logical.StorageEntryJSON(caManagedKeyStoragePath, &managedKeyStorageEntry{
		PublicKey: string(ssh.MarshalAuthorizedKey(keyInfo.PublicKey())),
		KeyName:   keyInfo.Name,
		KeyId:     keyInfo.Uuid,
	})
	if err != nil {
		return fmt.Errorf("error creating storage entry: %w", err)
	}

	// Save the managed key reference.
	if err := s.Put(ctx, entry); err != nil {
		return fmt.Errorf("error writing key entry to storage: %w", err)
	}

	return nil
}

// getCAPublicKey returns the CA public key. A locally stored key is preferred.
// Otherwise the public key recorded in the managed key entry is used. It
// returns "" with a nil error when neither exists. allowMigration is forwarded
// to the legacy-path handling of the stored-key lookup.
func getCAPublicKey(ctx context.Context, storage logical.Storage, allowMigration bool) (string, error) {
	stored, err := readStoredKey(ctx, storage, caPublicKey, allowMigration)
	if err != nil {
		return "", err
	}
	if stored != nil {
		return stored.Key, nil
	}

	managed, err := readManagedKey(ctx, storage)
	if err != nil {
		return "", err
	}
	if managed == nil {
		return "", nil
	}
	return managed.PublicKey, nil
}

// readManagedKey loads and decodes the managed key reference stored at
// caManagedKeyStoragePath. It returns (nil, nil) when no managed key is
// configured.
// ignore-nil-nil-function-check
func readManagedKey(ctx context.Context, storage logical.Storage) (*managedKeyStorageEntry, error) {
	raw, err := storage.Get(ctx, caManagedKeyStoragePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA key of type managed key: %w", err)
	}
	if raw == nil {
		return nil, nil
	}

	out := &managedKeyStorageEntry{}
	if err := raw.DecodeJSON(out); err != nil {
		return nil, err
	}
	return out, nil
}

// caKeysConfigured reports whether storage already holds CA material. That
// means a non-empty locally stored public or private key (at the current or
// deprecated path), or a managed key reference. Storage is only read; legacy
// entries are not migrated.
func caKeysConfigured(ctx context.Context, s logical.Storage) (bool, error) {
	// Existence check only; migration can happen on a later read.
	const allowMigration = false

	pub, err := readStoredKey(ctx, s, caPublicKey, allowMigration)
	if err != nil {
		return false, fmt.Errorf("failed to read CA public key: %w", err)
	}
	priv, err := readStoredKey(ctx, s, caPrivateKey, allowMigration)
	if err != nil {
		return false, fmt.Errorf("failed to read CA private key: %w", err)
	}

	nonEmpty := func(e *keyStorageEntry) bool { return e != nil && e.Key != "" }
	if nonEmpty(pub) || nonEmpty(priv) {
		return true, nil
	}

	managed, err := readManagedKey(ctx, s)
	if err != nil {
		return false, fmt.Errorf("failed to read CA managed key: %w", err)
	}
	return managed != nil, nil
}

// pathConfigCARecover copies the CA configuration from a loaded snapshot back
// into live storage.
//
// It refuses to run when live storage already has CA keys, which matches the
// behavior of pathConfigCAUpdate. Key entries are read straight from the
// snapshot rather than from req.Data. Any managed key reference is
// re-resolved through createManagedKey.
// ignore-nil-nil-function-check
func (b *backend) pathConfigCARecover(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	liveConfigured, err := caKeysConfigured(ctx, req.Storage)
	if err != nil {
		return nil, err
	}
	if liveConfigured {
		return logical.ErrorResponse("keys are already configured; delete them before recovering the CA"), nil
	}

	// req.Data only carries what an earlier read of the snapshot returned, which
	// is just the public key. The private key therefore has to come from the
	// snapshot storage itself.
	snapshot, err := logical.NewSnapshotStorageView(req)
	if err != nil {
		return nil, err
	}

	// Writes to the snapshot storage are not allowed, so legacy-path entries
	// must not be migrated in place.
	const allowMigration = false

	pubEntry, err := readStoredKeyEntry(ctx, snapshot, caPublicKey, allowMigration)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA public key for restore: %w", err)
	}
	privEntry, err := readStoredKeyEntry(ctx, snapshot, caPrivateKey, allowMigration)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA private key for restore: %w", err)
	}
	managed, err := readManagedKey(ctx, snapshot)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA managed key for restore: %w", err)
	}

	nothingToRestore := pubEntry == nil && privEntry == nil && managed == nil
	if nothingToRestore {
		return logical.ErrorResponse("no CA keys found in snapshot storage to restore"), nil
	}

	// Entries found at a legacy path were already rewrapped by
	// readStoredKeyEntry to target the current path, so they are restored there.
	if pubEntry != nil {
		if err := req.Storage.Put(ctx, pubEntry); err != nil {
			return nil, fmt.Errorf("failed to restore public key entry in storage: %w", err)
		}
	}
	if privEntry != nil {
		if err := req.Storage.Put(ctx, privEntry); err != nil {
			return nil, fmt.Errorf("failed to restore private key entry in storage: %w", err)
		}
	}
	if managed != nil {
		if err := b.createManagedKey(ctx, req.Storage, managed.KeyName.String(), managed.KeyId.String()); err != nil {
			return nil, fmt.Errorf("failed to restore managed key entry in storage: %w", err)
		}
	}

	return nil, nil
}
